{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Copyright 2015 Google Inc. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# ==============================================================================\n",
    "\n",
    "\"\"\"Functions for downloading and reading MNIST data.\"\"\"\n",
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import gzip\n",
    "import os\n",
    "\n",
    "import numpy\n",
    "from six.moves import urllib\n",
    "from six.moves import xrange  # pylint: disable=redefined-builtin\n",
    "import tensorflow as tf\n",
    "\n",
    "# Base URL of the MNIST archives; maybe_download appends the file name to it.\n",
    "SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'\n",
    "\n",
    "def maybe_download(filename, work_directory):\n",
    "  \"\"\"Download `filename` from SOURCE_URL, unless it's already here.\n",
    "\n",
    "  Args:\n",
    "    filename: Name of the archive; appended to SOURCE_URL to form the URL.\n",
    "    work_directory: Directory the file is stored in; created if missing.\n",
    "\n",
    "  Returns:\n",
    "    Path of the local copy (`work_directory` joined with `filename`).\n",
    "\n",
    "  Raises:\n",
    "    Any error raised by `urlretrieve` or the rename. In that case nothing is\n",
    "    left at the returned path, so the next call retries the download.\n",
    "  \"\"\"\n",
    "  if not tf.gfile.Exists(work_directory):\n",
    "    tf.gfile.MakeDirs(work_directory)\n",
    "  filepath = os.path.join(work_directory, filename)\n",
    "  if not tf.gfile.Exists(filepath):\n",
    "    # Fetch under a temporary name and rename only once the transfer has\n",
    "    # finished; otherwise an interrupted download would leave a truncated\n",
    "    # file that the existence check above accepts on the next run.\n",
    "    partial_path = filepath + '.part'\n",
    "    try:\n",
    "      urllib.request.urlretrieve(SOURCE_URL + filename, partial_path)\n",
    "      os.rename(partial_path, filepath)\n",
    "    except BaseException:\n",
    "      # Also covers KeyboardInterrupt: drop the partial file, then re-raise.\n",
    "      if os.path.exists(partial_path):\n",
    "        os.remove(partial_path)\n",
    "      raise\n",
    "    # urlretrieve wrote a local file, so the stdlib size lookup is enough.\n",
    "    size = os.path.getsize(filepath)\n",
    "    print('Successfully downloaded', filename, size, 'bytes.')\n",
    "  return filepath\n",
    "\n",
    "\n",
    "def _read32(bytestream):\n",
    "  dt = numpy.dtype(numpy.uint32).newbyteorder('>')\n",
    "  return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]\n",
    "\n",
    "\n",
    "def extract_images(filename):\n",
    "  \"\"\"Read a gzipped MNIST image file into a uint8 array [index, y, x, depth].\n",
    "\n",
    "  Raises:\n",
    "    ValueError: If the file does not start with the image magic number 2051.\n",
    "  \"\"\"\n",
    "  print('Extracting', filename)\n",
    "  with tf.gfile.Open(filename, 'rb') as raw_file:\n",
    "    with gzip.GzipFile(fileobj=raw_file) as stream:\n",
    "      header_magic = _read32(stream)\n",
    "      if header_magic != 2051:\n",
    "        raise ValueError(\n",
    "            'Invalid magic number %d in MNIST image file: %s' %\n",
    "            (header_magic, filename))\n",
    "      # The header continues with image count, row count, column count.\n",
    "      image_count = _read32(stream)\n",
    "      height = _read32(stream)\n",
    "      width = _read32(stream)\n",
    "      pixels = numpy.frombuffer(\n",
    "          stream.read(height * width * image_count), dtype=numpy.uint8)\n",
    "      # Trailing axis of length 1 is the depth dimension.\n",
    "      return pixels.reshape(image_count, height, width, 1)\n",
    "\n",
    "\n",
    "def dense_to_one_hot(labels_dense, num_classes=10):\n",
    "  \"\"\"Convert class labels from scalars to one-hot vectors.\"\"\"\n",
    "  num_labels = labels_dense.shape[0]\n",
    "  index_offset = numpy.arange(num_labels) * num_classes\n",
    "  labels_one_hot = numpy.zeros((num_labels, num_classes))\n",
    "  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1\n",
    "  return labels_one_hot\n",
    "\n",
    "\n",
    "def extract_labels(filename, one_hot=False):\n",
    "  \"\"\"Read a gzipped MNIST label file.\n",
    "\n",
    "  Returns the labels as a 1D uint8 numpy array [index], or the result of\n",
    "  `dense_to_one_hot` on that array when `one_hot` is true.\n",
    "\n",
    "  Raises:\n",
    "    ValueError: If the file does not start with the label magic number 2049.\n",
    "  \"\"\"\n",
    "  print('Extracting', filename)\n",
    "  with tf.gfile.Open(filename, 'rb') as raw_file:\n",
    "    with gzip.GzipFile(fileobj=raw_file) as stream:\n",
    "      header_magic = _read32(stream)\n",
    "      if header_magic != 2049:\n",
    "        raise ValueError(\n",
    "            'Invalid magic number %d in MNIST label file: %s' %\n",
    "            (header_magic, filename))\n",
    "      # The item count follows the magic number; each label is one byte.\n",
    "      label_count = _read32(stream)\n",
    "      dense = numpy.frombuffer(stream.read(label_count), dtype=numpy.uint8)\n",
    "      return dense_to_one_hot(dense) if one_hot else dense\n",
    "\n",
    "\n",
    "class DataSet(object):\n",
    "  \"\"\"In-memory images and labels served in sequential mini-batches.\"\"\"\n",
    "\n",
    "  def __init__(self, images, labels, fake_data=False, one_hot=False,\n",
    "               dtype=tf.float32):\n",
    "    \"\"\"Construct a DataSet.\n",
    "\n",
    "    one_hot arg is used only if fake_data is true.  `dtype` can be either\n",
    "    `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into\n",
    "    `[0, 1]`.\n",
    "\n",
    "    Args:\n",
    "      images: Array of shape [num examples, rows, columns, 1]; not inspected\n",
    "        when `fake_data` is true.\n",
    "      labels: Array whose first dimension matches `images`.\n",
    "      fake_data: If true, skip all checks and report 10000 examples.\n",
    "      one_hot: Selects the label form that fake batches use.\n",
    "      dtype: `uint8` or `float32` (anything `tf.as_dtype` accepts).\n",
    "\n",
    "    Raises:\n",
    "      TypeError: If `dtype` is neither uint8 nor float32.\n",
    "    \"\"\"\n",
    "    dtype = tf.as_dtype(dtype).base_dtype\n",
    "    if dtype not in (tf.uint8, tf.float32):\n",
    "      raise TypeError('Invalid image dtype %r, expected uint8 or float32' %\n",
    "                      dtype)\n",
    "    # Stored unconditionally so next_batch(fake_data=True) also works on a\n",
    "    # DataSet that was built from real data.\n",
    "    self.one_hot = one_hot\n",
    "    if fake_data:\n",
    "      self._num_examples = 10000\n",
    "    else:\n",
    "      assert images.shape[0] == labels.shape[0], (\n",
    "          'images.shape: %s labels.shape: %s' % (images.shape,\n",
    "                                                 labels.shape))\n",
    "      self._num_examples = images.shape[0]\n",
    "\n",
    "      # Convert shape from [num examples, rows, columns, depth]\n",
    "      # to [num examples, rows*columns] (assuming depth == 1)\n",
    "      assert images.shape[3] == 1\n",
    "      images = images.reshape(images.shape[0],\n",
    "                              images.shape[1] * images.shape[2])\n",
    "      if dtype == tf.float32:\n",
    "        # Convert from [0, 255] -> [0.0, 1.0].\n",
    "        images = images.astype(numpy.float32)\n",
    "        images = numpy.multiply(images, 1.0 / 255.0)\n",
    "    self._images = images\n",
    "    self._labels = labels\n",
    "    self._epochs_completed = 0\n",
    "    self._index_in_epoch = 0\n",
    "\n",
    "  @property\n",
    "  def images(self):\n",
    "    \"\"\"The stored images, in their current (possibly shuffled) order.\"\"\"\n",
    "    return self._images\n",
    "\n",
    "  @property\n",
    "  def labels(self):\n",
    "    \"\"\"The stored labels, kept aligned with `images`.\"\"\"\n",
    "    return self._labels\n",
    "\n",
    "  @property\n",
    "  def num_examples(self):\n",
    "    \"\"\"Number of examples in this data set.\"\"\"\n",
    "    return self._num_examples\n",
    "\n",
    "  @property\n",
    "  def epochs_completed(self):\n",
    "    \"\"\"How many times next_batch has wrapped around the data.\"\"\"\n",
    "    return self._epochs_completed\n",
    "\n",
    "  def next_batch(self, batch_size, fake_data=False):\n",
    "    \"\"\"Return the next `batch_size` examples from this data set.\n",
    "\n",
    "    Batches are consecutive slices of the stored arrays. When the next slice\n",
    "    would run past the end, the leftover examples are skipped, the data is\n",
    "    reshuffled and the batch is taken from the start of the new order.\n",
    "\n",
    "    Args:\n",
    "      batch_size: Number of examples to return.\n",
    "      fake_data: If true, return constant placeholder images and labels and\n",
    "        leave the read position untouched.\n",
    "\n",
    "    Returns:\n",
    "      An (images, labels) pair.\n",
    "    \"\"\"\n",
    "    if fake_data:\n",
    "      fake_image = [1] * 784\n",
    "      if self.one_hot:\n",
    "        fake_label = [1] + [0] * 9\n",
    "      else:\n",
    "        fake_label = 0\n",
    "      return [fake_image for _ in xrange(batch_size)], [\n",
    "          fake_label for _ in xrange(batch_size)]\n",
    "    start = self._index_in_epoch\n",
    "    end = start + batch_size\n",
    "    if end > self._num_examples:\n",
    "      # Checked before any state changes, so a failed call leaves the\n",
    "      # epoch counter, read position and example order as they were.\n",
    "      assert batch_size <= self._num_examples\n",
    "      # Finished epoch\n",
    "      self._epochs_completed += 1\n",
    "      # Shuffle the data; one permutation keeps images and labels aligned.\n",
    "      perm = numpy.arange(self._num_examples)\n",
    "      numpy.random.shuffle(perm)\n",
    "      self._images = self._images[perm]\n",
    "      self._labels = self._labels[perm]\n",
    "      # Start next epoch\n",
    "      start = 0\n",
    "      end = batch_size\n",
    "    self._index_in_epoch = end\n",
    "    return self._images[start:end], self._labels[start:end]\n",
    "\n",
    "\n",
    "def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):\n",
    "  \"\"\"Build the train, validation and test `DataSet`s for MNIST.\n",
    "\n",
    "  Args:\n",
    "    train_dir: Directory the four MNIST archives are read from, and\n",
    "      downloaded into when missing.\n",
    "    fake_data: If true, return placeholder data sets without reading files.\n",
    "    one_hot: If true, labels are extracted as one-hot rows.\n",
    "    dtype: Image dtype forwarded to each `DataSet`.\n",
    "\n",
    "  Returns:\n",
    "    An object with `train`, `validation` and `test` attributes. The first\n",
    "    5000 training examples make up `validation`; the rest make up `train`.\n",
    "  \"\"\"\n",
    "  class DataSets(object):\n",
    "    pass\n",
    "\n",
    "  bundle = DataSets()\n",
    "\n",
    "  if fake_data:\n",
    "    for split in ('train', 'validation', 'test'):\n",
    "      setattr(bundle, split,\n",
    "              DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype))\n",
    "    return bundle\n",
    "\n",
    "  holdout = 5000  # Leading training examples set aside for validation.\n",
    "\n",
    "  # Same order as before: train images, train labels, test images, test labels.\n",
    "  train_images = extract_images(\n",
    "      maybe_download('train-images-idx3-ubyte.gz', train_dir))\n",
    "  train_labels = extract_labels(\n",
    "      maybe_download('train-labels-idx1-ubyte.gz', train_dir),\n",
    "      one_hot=one_hot)\n",
    "  test_images = extract_images(\n",
    "      maybe_download('t10k-images-idx3-ubyte.gz', train_dir))\n",
    "  test_labels = extract_labels(\n",
    "      maybe_download('t10k-labels-idx1-ubyte.gz', train_dir),\n",
    "      one_hot=one_hot)\n",
    "\n",
    "  bundle.train = DataSet(train_images[holdout:], train_labels[holdout:],\n",
    "                         dtype=dtype)\n",
    "  bundle.validation = DataSet(train_images[:holdout], train_labels[:holdout],\n",
    "                              dtype=dtype)\n",
    "  bundle.test = DataSet(test_images, test_labels, dtype=dtype)\n",
    "\n",
    "  return bundle"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
